{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Neighborhood Components Analysis Illustration\n\n\nThis example illustrates a learned distance metric that maximizes\nnearest-neighbors classification accuracy. It gives a visual\nrepresentation of this metric compared to the original point\nspace. Please refer to the scikit-learn User Guide section on Neighborhood\nComponents Analysis for more information.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# License: BSD 3 clause\n\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom matplotlib import cm\nfrom scipy.special import logsumexp\nfrom sklearn.datasets import make_classification\nfrom sklearn.neighbors import NeighborhoodComponentsAnalysis"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Original points\n---------------\nFirst we create a data set of 9 samples from 3 classes, and plot the points\nin the original space. For this example, we focus on the classification of\npoint no. 3. The thickness of a link between point no. 3 and another point\nreflects how close they are: each link is weighted by a softmax over negative\nsquared distances, so nearer points get thicker links.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X, y = make_classification(n_samples=9, n_features=2, n_informative=2,\n                           n_redundant=0, n_classes=3, n_clusters_per_class=1,\n                           class_sep=1.0, random_state=0)\n\nplt.figure(1)\nax = plt.gca()\nfor i in range(X.shape[0]):\n    ax.text(X[i, 0], X[i, 1], str(i), va='center', ha='center')\n    ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)\n\nax.set_title(\"Original points\")\nax.axes.get_xaxis().set_visible(False)\nax.axes.get_yaxis().set_visible(False)\nax.axis('equal')  # so that boundaries are displayed correctly as circles\n\n\ndef link_thickness_i(X, i):\n    # squared Euclidean distances from point i to every point\n    diff_embedded = X[i] - X\n    dist_embedded = np.einsum('ij,ij->i', diff_embedded,\n                              diff_embedded)\n    # exclude point i itself: exp(-inf) gives it a weight of 0\n    dist_embedded[i] = np.inf\n\n    # compute exponentiated distances (use the log-sum-exp trick to\n    # avoid numerical instabilities)\n    exp_dist_embedded = np.exp(-dist_embedded -\n                               logsumexp(-dist_embedded))\n    return exp_dist_embedded\n\n\ndef relate_point(X, i, ax):\n    pt_i = X[i]\n    # the weights do not depend on j, so compute them once\n    thickness = link_thickness_i(X, i)\n    for j, pt_j in enumerate(X):\n        if i != j:\n            line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])\n            ax.plot(*line, c=cm.Set1(y[j]),\n                    linewidth=5*thickness[j])\n\n\ni = 3\nrelate_point(X, i, ax)\nplt.show()"
      ]
    },
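    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The weights returned by `link_thickness_i` can be read as a softmax over negative squared distances. As a sketch, writing $x_i$ for the reference point (notation introduced here for illustration, not used in the code), the weight of the link to point $j$ is\n\n$$p_{ij} = \\frac{\\exp(-\\lVert x_i - x_j \\rVert^2)}{\\sum_{k \\neq i} \\exp(-\\lVert x_i - x_k \\rVert^2)}, \\qquad p_{ii} = 0,$$\n\nand `relate_point` draws each link with a line width of $5\\,p_{ij}$. The weights sum to 1 over $j$, so points closer to $x_i$ get thicker links.\n"
      ]
    },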
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Learning an embedding\n---------------------\nWe use `NeighborhoodComponentsAnalysis` from `sklearn.neighbors` to learn an\nembedding and plot the points after the transformation. We then take the\nembedding and redraw the links from point no. 3 to find its nearest neighbors.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)\nnca = nca.fit(X, y)\n\nplt.figure(2)\nax2 = plt.gca()\nX_embedded = nca.transform(X)\nrelate_point(X_embedded, i, ax2)\n\n# use j as the loop variable so the reference index i is not overwritten\nfor j in range(len(X_embedded)):\n    ax2.text(X_embedded[j, 0], X_embedded[j, 1], str(j),\n             va='center', ha='center')\n    ax2.scatter(X_embedded[j, 0], X_embedded[j, 1], s=300, c=cm.Set1(y[[j]]),\n                alpha=0.4)\n\nax2.set_title(\"NCA embedding\")\nax2.axes.get_xaxis().set_visible(False)\nax2.axes.get_yaxis().set_visible(False)\nax2.axis('equal')\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}